from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
import logging
import tqdm
import json
from openai import OpenAI


'''
prompts = [
    "You will be given a story, a question, and the corresponding choices. Please answer this question:\n\nStory: 1 Mila entered the den.\n2 Amelia entered the den.\n3 Lucas entered the hallway.\n4 The peas is in the container.\n5 The container is in the den.\n6 Mila exited the den.\n7 Lucas exited the hallway.\n8 Amelia moved the peas to the drawer.\n9 The drawer is in the den.\nQuestion: Where was the peas at the beginning?\nChoices: ['container', 'drawer']"
]
'''

# Checkpoint directory shared by the vLLM engine and the tokenizer. Keeping it
# in one place guarantees both are loaded from the same checkpoint.
checkpoint_dir = ".../data/Qwen2.5-7B/checkpoint-3354"
model = LLM(model=checkpoint_dir)
tok = AutoTokenizer.from_pretrained(checkpoint_dir)

# JSONL file with one ToMi example per line (read by the evaluation loop).
tomi_datadir = ".../data/fixedtomi/test_balanced.jsonl"

# Running tallies filled in by the evaluation loop:
#   category_results:  "count" + question_type -> {"correct": int, "total": int}
#   category_percents: question_type -> running accuracy for that type
category_results = {}
category_percents = {}
correctNum = 0
totalNum = 0

# NOTE(review): the log file path is absolute at the filesystem root —
# presumably a redacted/placeholder prefix; confirm it is writable before running.
logging.basicConfig(
    filename="/log_reasoning_budgetforce_tomi.log",
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Evaluation / budget-forcing knobs (values match the previously hard-coded ones).
MAX_EXAMPLES = 200      # number of dataset lines to evaluate
NUM_WAITS = 1           # how many times the continuation string is appended
WAIT_STR = "Wait"       # appended after the model stops, to force more reasoning


def _is_true_verdict(verdict):
    """Return True when the judge's reply is the single word "True".

    The comparison ignores case, surrounding whitespace, quotes, markdown
    emphasis and trailing punctuation, so replies such as ``"True."``,
    ``" true\\n"`` or ``"'True'"`` are accepted. ``None`` or an empty reply
    counts as False.
    """
    if not verdict:
        return False
    return verdict.strip().strip("'\"`*.!").strip().casefold() == "true"


def _judge(client, question, response, label):
    """Ask the judge model whether *response* matches *label* for *question*.

    Sends one chat-completion request through *client*, prints the raw
    verdict, and returns it parsed as a bool via ``_is_true_verdict``.
    """
    prompt_evaluate = f"""\
[Question: {question}]

***[Response Answer: {response}]***

***[Correct Answer: {label}]***

Only based on the ***[Correct Answer]***, judge whether the ***[Response Answer]*** is correct. Output 'True'or 'False' only.  
"""
    completion = client.chat.completions.create(
        model="deepseek-v3",
        messages=[
            {'role': 'system', 'content': 'You are a helpful assistant.'},
            {'role': 'user', 'content': prompt_evaluate}
        ],
        temperature=0.0
    )
    verdict = completion.choices[0].message.content
    print(verdict)
    return _is_true_verdict(verdict)


# Loop-invariant objects, built once instead of on every example.
# NOTE(review): if "</think>" tokenizes to several ids, each id is passed as an
# independent stop token — confirm this is the intended stopping condition.
stop_token_ids = tok("</think>")["input_ids"]
sampling_params = SamplingParams(
    max_tokens=4096,
    stop_token_ids=stop_token_ids,
    temperature=0.0,
)
# NOTE(review): api_key and base_url are empty here — presumably redacted;
# fill them in before running.
client = OpenAI(
    api_key="",
    base_url="",
)

with open(tomi_datadir, encoding="utf-8") as f_in:
    for index, line in tqdm.tqdm(enumerate(f_in), total=MAX_EXAMPLES):
        # Stop before parsing anything past the evaluation limit.
        if index >= MAX_EXAMPLES:
            break

        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. a trailing newline) instead of
            # crashing inside json.loads.
            continue

        fields = json.loads(line)
        story = fields["story"]
        question = fields["question"]
        label = fields["answer"]
        containers = fields["containers"]
        story_type = fields["story_type"]
        question_type = fields["question_type"]

        prompt = f"""\
Story: [{story}]
Question: [{question}]
Give the answer to this question.
Choose from the following:
{containers[0]}, {containers[1]}.
"""

        # First pass: let the model generate until it stops on its own, and
        # grade that output as the "before budget forcing" baseline.
        o = model.generate(
            prompt,
            sampling_params=sampling_params
        )
        before_correct = _judge(client, question, o[0].outputs[0].text, label)

        # Budget forcing: append the continuation string to the model's own
        # output and let it keep generating.
        for _ in range(NUM_WAITS):
            prompt += o[0].outputs[0].text + WAIT_STR
            o = model.generate(
                prompt,
                sampling_params=sampling_params
            )

        ### Final answer ###
        prompt += o[0].outputs[0].text
        prompt += "Final Answer"
        o = model.generate(
            prompt,
            sampling_params=sampling_params,
        )
        print("With budget forcing:")
        print(prompt + o[0].outputs[0].text)

        prediction = o[0].outputs[0].text

        correct = _judge(client, question, prediction, label)
        if correct:
            correctNum += 1
        totalNum += 1

        logging.info(f"Index: {index}")
        logging.info(f"Story: {story}")
        logging.info(f"Question: {question}")
        logging.info(f"Prediction: {prediction}")
        logging.info(f"Label: {label}")
        logging.info(f"**********Before_Correct**********: {before_correct}")
        logging.info(f"**********Correct**********: {correct}")
        logging.info(f"Story_Type: {story_type}")
        logging.info(f"Question_Type: {question_type}")

        # Per-question-type tally and running accuracy.
        tally = category_results.setdefault(
            "count" + question_type, {"correct": 0, "total": 0}
        )
        if correct:
            tally["correct"] += 1
        tally["total"] += 1
        category_percents[question_type] = tally["correct"] / tally["total"]

# Overall accuracy across all evaluated examples. Guard against an empty run
# (no example evaluated) so the summary is still printed and logged instead of
# raising ZeroDivisionError.
accuracy = correctNum / totalNum if totalNum else 0.0
print(correctNum)
print(totalNum)
print(f"Accuracy: {accuracy*100:.3f}%")


# Persist the per-category tallies and the overall accuracy to the log file.
logging.info(f"category_results: {category_results}")
logging.info(f"category_percents: {category_percents}")
logging.info(f"Accuracy: {accuracy*100:.3f}%")



